import org.encog.Encog;
import org.encog.ml.data.MLDataSet;
import org.encog.ml.data.basic.BasicMLDataSet;
import org.encog.neural.networks.BasicNetwork;
import org.encog.neural.networks.layers.BasicLayer;
import org.encog.neural.networks.training.propagation.back.Backpropagation;
import org.encog.engine.network.activation.ActivationSigmoid;
import org.encog.util.csv.CSVFormat;
import org.encog.util.simple.EncogUtility;
import org.encog.util.simple.TrainingSetUtil;
import org.encog.mathutil.randomize.ConsistentRandomizer;
import java.io.*;

/**
 * Background thread that steers the robot with a fixed obstacle-avoidance rule
 * while periodically recording (normalized sensor readings, normalized wheel
 * speeds) samples as CSV lines in {@code out.txt}. After a fixed number of control
 * steps, recording stops, the file is closed, and its contents are used to train
 * an Encog feed-forward network which is handed to the robot via
 * {@code robot.setNN(...)}.
 */
public class LearnThread extends Thread {
	/** CSV file that samples are written to and later trained from. */
	private static final String DATA_FILE = "out.txt";
	/** Number of sensors recorded per sample (indices 0..5 of {@code robot.sensors}). */
	private static final int SENSOR_COUNT = 6;
	/** Number of control steps after which recording ends and training starts. */
	private static final int MAX_STEPS = 10000000;
	/** One sample is written every this many control steps. */
	private static final int SAMPLE_INTERVAL = 10000;

	Robot robot;
	/** Cleared by {@link #stopLearn()}, possibly from another thread, hence volatile. */
	volatile boolean alive = true;
	private final int maxDistance = 700;
	private final int forwardSpeed = 5;
	private final int turnSpeed = 3;
	/** Number of control steps taken so far. */
	private int poc = 0;
	/** Sample sink; null if the file could not be opened or has been closed. */
	private Writer out;
	/** Ensures the network is trained at most once. */
	private boolean trained = false;

	LearnThread(Robot robot) {
		this.robot = robot;
		try {
			out = new BufferedWriter(new OutputStreamWriter(new FileOutputStream(new File(DATA_FILE))));
		} catch (IOException e) {
			// Keep the thread usable, but make the failure visible instead of
			// silently leaving 'out' null.
			System.err.println("LearnThread: cannot open " + DATA_FILE + " for writing: " + e);
			out = null;
		}
	}

	/** Runs control steps until {@link #stopLearn()} is called or training finishes. */
	@Override
	public void run() {
		try {
			while (alive) {
				learn();
			}
		} finally {
			// Flush whatever was recorded, also when stopped before the step limit.
			closeOutput();
		}
	}

	/**
	 * Returns true when a sensor reading exceeds the threshold.
	 * NOTE(review): assumes readings grow as an obstacle gets closer — confirm
	 * against the sensor implementation.
	 */
	private boolean isNear(int dist) {
		return dist > maxDistance;
	}

	/** Returns true if any sensor in the inclusive index range reports a near obstacle. */
	private boolean anyNear(int first, int last) {
		for (int i = first; i <= last; i++) {
			if (isNear(robot.sensors[i].getDistValue())) {
				return true;
			}
		}
		return false;
	}

	/** Sensors 0..2 are treated as the left side. */
	private boolean leftSideIsNear() {
		return anyNear(0, 2);
	}

	/** Sensors 3..5 are treated as the right side. */
	private boolean rightSideIsNear() {
		return anyNear(3, 5);
	}

	/**
	 * Performs one control step: turn away from a near obstacle (left side checked
	 * first) or drive forward, and record the chosen wheel speeds.
	 */
	public void learn() {
		boolean left = leftSideIsNear(), right = rightSideIsNear();

		if (left) {
			robot.turn(turnSpeed, -turnSpeed);
			writeToFile(turnSpeed, -turnSpeed);
			return;
		}
		if (right) {
			robot.turn(-turnSpeed, turnSpeed);
			writeToFile(-turnSpeed, turnSpeed);
			return;
		}
		robot.forward();
		writeToFile(forwardSpeed, forwardSpeed);
	}

	/**
	 * Builds a 6-6-2 sigmoid network, trains it with backpropagation on the CSV
	 * samples in {@link #DATA_FILE} until the training error drops to the target,
	 * and passes it to the robot.
	 * NOTE(review): the loop has no epoch cap and runs until the error converges.
	 */
	private void createNN() {
		final int inputNeurons = SENSOR_COUNT, hiddenNeurons = 6, outputNeurons = 2;
		final double nnError = 0.05;

		BasicNetwork network = new BasicNetwork();
		network.addLayer(new BasicLayer(new ActivationSigmoid(), true, inputNeurons));
		network.addLayer(new BasicLayer(new ActivationSigmoid(), true, hiddenNeurons));
		network.addLayer(new BasicLayer(new ActivationSigmoid(), true, outputNeurons));
		network.getStructure().finalizeStructure();
		network.reset();
		// Fixed seed so repeated runs start from the same weights.
		new ConsistentRandomizer(-1, 1, 500).randomize(network);

		// Each CSV line: SENSOR_COUNT inputs followed by the two wheel-speed ideals.
		MLDataSet trainingSet = TrainingSetUtil.loadCSVTOMemory(
				CSVFormat.ENGLISH, DATA_FILE, false, inputNeurons, outputNeurons);

		final Backpropagation train = new Backpropagation(network, trainingSet);
		train.fixFlatSpot(false);

		int epoch = 1;
		do {
			train.iteration();
			System.out.println("Epoch #" + epoch + " Error:" + train.getError());
			epoch++;
		} while (train.getError() > nnError);
		train.finishTraining();

		robot.setNN(network);
	}

	/**
	 * Counts a control step, writing a sample every {@link #SAMPLE_INTERVAL} steps;
	 * once {@link #MAX_STEPS} is reached, stops learning and trains the network.
	 */
	private void writeToFile(int leftSpeed, int rightSpeed) {
		if (poc >= MAX_STEPS) {
			finishLearning();
			return;
		}
		if (poc % SAMPLE_INTERVAL == 0) {
			writeSample(leftSpeed, rightSpeed);
		}
		poc++;
	}

	/** Appends one CSV line: six normalized sensor values, then the two normalized speeds. */
	private void writeSample(int leftSpeed, int rightSpeed) {
		if (out == null) {
			return;
		}
		StringBuilder line = new StringBuilder();
		for (int i = 0; i < SENSOR_COUNT; i++) {
			line.append(robot.normalizeSensorValue(robot.sensors[i].getDistValue())).append(',');
		}
		line.append(robot.normalizeSpeed(leftSpeed)).append(',')
				.append(robot.normalizeSpeed(rightSpeed)).append('\n');
		try {
			// Single write so a failure never leaves half a line in the file.
			out.write(line.toString());
		} catch (IOException e) {
			System.err.println("LearnThread: failed to write sample: " + e);
		}
	}

	/** Stops the learning loop and trains the network exactly once. */
	private void finishLearning() {
		if (trained) {
			return;
		}
		trained = true;
		robot.stopLearn();
		alive = false;
		// Training reads DATA_FILE back, so buffered samples must be flushed and
		// the file closed first; otherwise the tail (possibly a partial line) is lost.
		closeOutput();
		createNN();
	}

	/** Closes the sample writer if it is still open; safe to call repeatedly. */
	private void closeOutput() {
		if (out == null) {
			return;
		}
		try {
			out.close();
		} catch (IOException e) {
			System.err.println("LearnThread: failed to close " + DATA_FILE + ": " + e);
		} finally {
			out = null;
		}
	}

	/** Requests the learning loop to stop after the current step. */
	public void stopLearn() {
		alive = false;
	}
}
